import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
from recbole_cdr.model.crossdomain_recommender import CrossDomainRecommender
from recbole_cdr.model.cross_domain_recommender.graphcdr import GraphCDR
from recbole.model.init import xavier_normal_initialization
from recbole.model.loss import EmbLoss
from recbole.utils import InputType
from hyperbolic_gnn.model.hgcn.layers.gnn import GAT,GCN,GATv2,GraphSAGE
from hyperbolic_gnn.model.hgcn.layers.euclidean_contrastive_learning import GraphContrastive
class Observation(GraphCDR):
    input_type = InputType.POINTWISE
    def __init__(self, config, dataset):
        """Set up hyper-parameters, graph matrices, embeddings and loss modules.

        Args:
            config: Configuration indexed by key; the keys read here are
                ``device``, ``embedding_size``, ``n_layers``, ``reg_weight``,
                ``conv`` and ``temp``.
            dataset: Cross-domain dataset exposing ``source_domain_dataset``
                and ``target_domain_dataset``, each with a ``label_field``.

        Note:
            ``source_interaction_matrix``, ``target_interaction_matrix``,
            ``total_num_users``, ``total_num_items`` and ``drop_rate`` are
            read but not assigned in this method; presumably the parent
            constructor provides them -- TODO confirm against ``GraphCDR``.
        """
        super().__init__(config, dataset)

        # Dataset field names and hyper-parameters.
        self.config = config
        self.SOURCE_LABEL = dataset.source_domain_dataset.label_field
        self.TARGET_LABEL = dataset.target_domain_dataset.label_field
        self.device = config['device']
        self.latent_dim = config['embedding_size']
        self.n_layers = config['n_layers']
        self.reg_weight = config['reg_weight']

        # The normalised adjacency matrices are only built for LightGCN.
        if config['conv'] == 'lightgcn':
            n_users, n_items = self.total_num_users, self.total_num_items
            source_adj = self.get_norm_adj_mat(self.source_interaction_matrix, n_users, n_items)
            self.source_norm_adj_matrix = source_adj.to(self.device)
            target_adj = self.get_norm_adj_mat(self.target_interaction_matrix, n_users, n_items)
            self.target_norm_adj_matrix = target_adj.to(self.device)
            merged_adj = self.get_norm_adj_mat_a(self.source_interaction_matrix,
                                                 self.target_interaction_matrix)
            self.merge_norm_adj_matrix = merged_adj.to(self.device)
        self.temp = config['temp']

        # Build one overall graph: a single user table and a single item table
        # shared by both domains.
        self.user_embedding = nn.Embedding(self.total_num_users, self.latent_dim)
        self.item_embedding = nn.Embedding(self.total_num_items, self.latent_dim)

        # Layers and loss functions.
        self.dropout = nn.Dropout(p=self.drop_rate)
        self.loss = nn.BCELoss()
        self.sigmoid = nn.Sigmoid()
        self.reg_loss = EmbLoss()

        def degree_tensor(matrix, axis):
            # Sum of the interaction matrix along one axis, as a torch tensor.
            return torch.from_numpy(matrix.sum(axis=axis))

        source_inter = self.source_interaction_matrix
        target_inter = self.target_interaction_matrix
        self.source_user_degree_count = degree_tensor(source_inter, 1).to(self.device)
        self.target_user_degree_count = degree_tensor(target_inter, 1).to(self.device)
        # Column sums come back as a row, so they are transposed into a column.
        self.source_item_degree_count = degree_tensor(source_inter, 0).transpose(0, 1).to(self.device)
        self.target_item_degree_count = degree_tensor(target_inter, 0).transpose(0, 1).to(self.device)

        # Cache slots filled by get_restore_e() and cleared by init_restore_e().
        self.target_restore_user_e = None
        self.target_restore_item_e = None

        self.apply(xavier_normal_initialization)
        self.other_parameter_name = ['target_restore_user_e', 'target_restore_item_e']
    def get_norm_adj_mat(self, interaction_matrix, n_users=None, n_items=None):
        # build adj matrix
        if n_users == None or n_items == None:
            n_users, n_items = interaction_matrix.shape
        A = sp.dok_matrix((n_users + n_items, n_users + n_items), dtype=np.float32)
        inter_M = interaction_matrix
        inter_M_t = interaction_matrix.transpose()
        data_dict = dict(zip(zip(inter_M.row, inter_M.col + n_users), [1] * inter_M.nnz))
        data_dict.update(dict(zip(zip(inter_M_t.row + n_users, inter_M_t.col), [1] * inter_M_t.nnz)))
        A._update(data_dict)
        # norm adj matrix
        sumArr = (A > 0).sum(axis=1)
        # add epsilon to avoid divide by zero Warning
        diag = np.array(sumArr.flatten())[0] + 1e-7
        diag = np.power(diag, -0.5)
        D = sp.diags(diag)
        L = D * A * D
        # covert norm_adj matrix to tensor
        L = sp.coo_matrix(L)
        row = L.row
        col = L.col
        i = torch.LongTensor([row, col])
        data = torch.FloatTensor(L.data)
        SparseL = torch.sparse.FloatTensor(i, data, torch.Size(L.shape))
        return SparseL
    def get_norm_adj_mat_a(self, interaction_matrix_s, interaction_matrix_t):
        # build adj matrix
        A = sp.dok_matrix(
            (self.total_num_users + self.total_num_items, self.total_num_users + self.total_num_items), dtype=np.float32
        )
        inter_S = interaction_matrix_s
        inter_S_t = interaction_matrix_s.transpose()
        inter_T = interaction_matrix_t
        inter_T_t = interaction_matrix_t.transpose()
        data_dict = dict(
            zip(zip(inter_S.row, inter_S.col + self.total_num_users), [1] * inter_S.nnz)
        )
        data_dict.update(
            dict(
                zip(
                    zip(inter_S_t.row + self.total_num_users, inter_S_t.col),
                    [1] * inter_S_t.nnz,
                )
            )
        )
        data_dict.update(
            dict(
                zip(
                    zip(inter_T.row, inter_T.col + self.total_num_users),
                    [1] * inter_T.nnz,
                )
            )
        )
        data_dict.update(
            dict(
                zip(
                    zip(inter_T_t.row + self.total_num_users, inter_T_t.col),
                    [1] * inter_T_t.nnz,
                )
            )
        )
        A._update(data_dict)
        # norm adj matrix
        sumArr = (A > 0).sum(axis=1)
        # add epsilon to avoid divide by zero Warning
        diag = np.array(sumArr.flatten())[0] + 1e-7
        diag = np.power(diag, -0.5)
        D = sp.diags(diag)
        L = D * A * D
        # covert norm_adj matrix to tensor
        L = sp.coo_matrix(L)
        row = L.row
        col = L.col
        i = torch.LongTensor(np.array([row, col]))
        data = torch.FloatTensor(L.data)
        SparseL = torch.sparse.FloatTensor(i, data, torch.Size(L.shape))
        return SparseL
    def lightgcn(self,all_embeddings,norm_adj_matrix):
        embeddings_list = [all_embeddings]
        for layer_idx in range(self.n_layers):
            all_embeddings = torch.sparse.mm(norm_adj_matrix, all_embeddings)
            embeddings_list.append(all_embeddings)
        lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
        lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)
        return lightgcn_all_embeddings
    def get_ego_embeddings(self):
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings

    def forward(self):
        all_embeddings = self.get_ego_embeddings()
        all_embeddings =self.lightgcn(all_embeddings,
                                          self.merge_norm_adj_matrix)
        user_all_embeddings, item_all_embeddings = torch.split(all_embeddings,
                                                               [self.total_num_users,
                                                                self.total_num_items])
        return user_all_embeddings, item_all_embeddings






    def calculate_loss(self, interaction):
        self.init_restore_e()
        user_all_embeddings, item_all_embeddings = self.forward()
        losses = []
        source_user = interaction[self.SOURCE_USER_ID]
        source_item = interaction[self.SOURCE_ITEM_ID]
        source_label = interaction[self.SOURCE_LABEL]

        if self.config['mask_delete']==True:
           source_user,source_item,source_label=self.mask_delete(source_user,source_item,source_label)
        target_user = interaction[self.TARGET_USER_ID]
        target_item = interaction[self.TARGET_ITEM_ID]
        target_label = interaction[self.TARGET_LABEL]
        source_u_embeddings = user_all_embeddings[source_user]
        source_i_embeddings = item_all_embeddings[source_item]
        target_u_embeddings = user_all_embeddings[target_user]
        target_i_embeddings = item_all_embeddings[target_item]
        if self.config['setting']=='transfer':
            source_output = self.sigmoid(torch.mul(source_u_embeddings, source_i_embeddings).sum(dim=1))
            source_bce_loss = self.loss(source_output, source_label)
            # calculate Reg Loss in source domain
            u_source_ego_embeddings = self.source_user_embedding(source_user)
            i_source_ego_embeddings = self.source_item_embedding(source_item)
            source_reg_loss = self.reg_loss(u_source_ego_embeddings, i_source_ego_embeddings)
            losses = source_bce_loss + self.reg_weight * source_reg_loss
            return losses

        elif self.config['setting']=='merge':
            source_output = self.sigmoid(torch.mul(source_u_embeddings, source_i_embeddings).sum(dim=1))
            source_bce_loss = self.loss(source_output, source_label)
            u_source_ego_embeddings = self.source_user_embedding(source_user)
            i_source_ego_embeddings = self.source_item_embedding(source_item)
            source_reg_loss = self.reg_loss(u_source_ego_embeddings, i_source_ego_embeddings)
            source_loss = source_bce_loss + self.reg_weight * source_reg_loss
            losses.append(source_loss)
            # calculate BCE Loss in target domain
            target_output = self.sigmoid(torch.mul(target_u_embeddings, target_i_embeddings).sum(dim=1))
            target_bce_loss = self.loss(target_output, target_label)
            # calculate Reg Loss in target domain
            u_target_ego_embeddings = self.target_user_embedding(target_user)
            i_target_ego_embeddings = self.target_item_embedding(target_item)
            target_reg_loss = self.reg_loss(u_target_ego_embeddings, i_target_ego_embeddings)
            target_loss = target_bce_loss + self.reg_weight * target_reg_loss
            losses.append(target_loss)
            return tuple(losses)


        elif self.config['setting']=='source_passing':
            target_output = self.sigmoid(torch.mul(target_u_embeddings, target_i_embeddings).sum(dim=1))
            target_bce_loss = self.loss(target_output, target_label)
            # calculate Reg Loss in target domain
            u_target_ego_embeddings = self.target_user_embedding(target_user)
            i_target_ego_embeddings = self.target_item_embedding(target_item)
            target_reg_loss = self.reg_loss(u_target_ego_embeddings, i_target_ego_embeddings)
            losses = target_bce_loss + self.reg_weight * target_reg_loss
            return losses


    def predict(self, interaction):
        result = []
        _, _, target_user_all_embeddings, target_item_all_embeddings = self.forward()
        user = interaction[self.TARGET_USER_ID]
        item = interaction[self.TARGET_ITEM_ID]
        u_embeddings = target_user_all_embeddings[user]
        i_embeddings = target_item_all_embeddings[item]
        scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
        return scores

    def full_sort_predict(self, interaction):
        user = interaction[self.TARGET_USER_ID]
        restore_user_e, restore_item_e = self.get_restore_e()
        u_embeddings = restore_user_e[user]
        i_embeddings = restore_item_e[:self.target_num_items]
        scores = torch.matmul(u_embeddings, i_embeddings.transpose(0, 1))
        return scores.view(-1)

    def init_restore_e(self):
        # clear the storage variable when training
        if self.target_restore_user_e is not None or self.target_restore_item_e is not None:
            self.target_restore_user_e, self.target_restore_item_e = None, None

    def get_restore_e(self):
        if self.target_restore_user_e is None or self.target_restore_item_e is None:
            self.target_restore_user_e, self.target_restore_item_e = self.forward()
        return self.target_restore_user_e, self.target_restore_item_e